跳到主要内容

数据结构 Trie Tree 前缀树 字典树

前缀树是什么?

先提出问题:

一个字符串类型的数组 arr1,另一个字符串类型的数组 arr2。arr2 中有哪些字符,是 arr1 中出现的?请打印。arr2 中有哪些字符,是作为 arr1 中某个字符串前缀出现的? 请打印。 arr2 中有哪些字符,是作为 arr1 中某个字符串前缀出现的?请 打印 arr2 中出现次数最大的前缀

例一:

好比假设有 b,abc,abd,bcd,abcd,efg,hii 这 6 个单词,那我们创建 trie 树就得到

例二:

注意:节点上是不存值的,这些值都是存在路径上面

代码实现

节点的设计

public static class TrieNode {
public int pass; // 表示加字符时这个节点被通过了几次
public int end; // e 表示以他为结尾点的个数

// 因为这里默认输入 26 位小写字母,所以直接用数组,
// 如果不确定输入的字符可以使用 HashMap
public TrieNode[] nexts; // 相当于存储了 26 条路径

public TrieNode() {
pass = 0;
end = 0;
// nexts[0] == null 没有走向 'a' 的路
// nexts[0] != null 有走向 'a' 的路
// ...
// nexts[25] != null 有走向 'z' 的路
nexts = new TrieNode[26];
}
}

注意看这里的 pass 和 end

通过这两个参数就可以快速知道有某个字符串出现的次数了(看尾节点 pass 的次数),同时也可以知道是否存在某个字符串(看尾部 end 是否不为 0)

前缀树代码编写

    /**
* 前缀树
*/
public static class Trie {
private TrieNode root;

public Trie() {
this.root = new TrieNode();
}

public void insert(String word) {
if (word == null)
return;
char[] cs = word.toCharArray();
TrieNode node = root;
node.pass++;
int index = 0;
for (int i = 0; i < cs.length; i++) {
index = cs[i] - 'a';
if (node.nexts[index] == null) {
node.nexts[index] = new TrieNode();
}
node = node.nexts[index];
node.pass++;
}
node.end++;
}

public void delete(String word) {
if (search(word) == 0) return;
char[] cs = word.toCharArray();
TrieNode node = root;
node.pass--;
int index = 0;
for (int i = 0; i < cs.length; i++) {
index = cs[i] - 'a';
if (--node.nexts[index].pass == 0) {
node.nexts[index] = null;
return;
}
node = node.nexts[index];
}
node.end--;
}

/**
* 搜索这个 word 出现几次
*/
public int search(String word) {
if (word == null)
return 0;
char[] cs = word.toCharArray();
TrieNode node = root;
int index = 0;
for (int i = 0; i < cs.length; i++) {
index = cs[i] - 'a';
if (node.nexts[index] == null) {
return 0;
}
node = node.nexts[index];
}
return node.end;
}

/**
* 所有加入的字符串中,有几个以 pre 这个字符串作为前缀的
*/
public int prefixNumber(String pre) {
if (pre == null)
return 0;
char[] cs = pre.toCharArray();
TrieNode node = root;
int index = 0;
for (int i = 0; i < cs.length; i++) {
index = cs[i] - 'a';
if (node.nexts[index] == null) {
return 0;
}
node = node.nexts[index];
}
return node.pass;
}
}

编写测试代码

    public static void main(String[] args) {
Trie trie = new Trie();
trie.insert("hello");
trie.insert("helloword");
System.out.println(trie.search("hello"));
System.out.println(trie.prefixNumber("hello"));
trie.delete("hello");
System.out.println(trie.prefixNumber("hello"));
}